{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7.2 torchtextでのDataset、DataLoaderの実装方法\n",
    "\n",
    "- 本ファイルでは、torchtextを使用してDatasetおよびDataLoaderを実装する方法を解説します。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "※　本章のファイルはすべてUbuntuでの動作を前提としています。Windowsなど文字コードが違う環境での動作にはご注意下さい。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7.2 学習目標\n",
    "\n",
    "1.\ttorchtextを用いてDatasetおよびDataLoaderの実装ができる"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 事前準備\n",
    "\n",
    "- 書籍の指示に従い、本章で使用するデータを用意します\n",
    "\n",
    "- torchtextをインストールします\n",
    "\n",
    "- pip install torchtext\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 . 前処理と単語分割の関数を実装\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 単語分割にはJanomeを使用\n",
    "from janome.tokenizer import Tokenizer\n",
    "\n",
    "j_t = Tokenizer()\n",
    "\n",
    "\n",
    "def tokenizer_janome(text):\n",
    "    return [tok for tok in j_t.tokenize(text, wakati=True)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 前処理として正規化をする関数を定義\n",
    "import re\n",
    "\n",
    "\n",
    "def preprocessing_text(text):\n",
    "    # 半角・全角の統一\n",
    "    # 今回は無視\n",
    "\n",
    "    # 英語の小文字化\n",
    "    # 今回はここでは無視\n",
    "    # output = output.lower()\n",
    "\n",
    "    # 改行、半角スペース、全角スペースを削除\n",
    "    text = re.sub('\\r', '', text)\n",
    "    text = re.sub('\\n', '', text)\n",
    "    text = re.sub('　', '', text)\n",
    "    text = re.sub(' ', '', text)\n",
    "\n",
    "    # 数字文字の一律「0」化\n",
    "    text = re.sub(r'[0-9 ０-９]', '0', text)  # 数字\n",
    "\n",
    "    # 記号と数字の除去\n",
    "    # 今回は無視。半角記号,数字,英字\n",
    "    # 今回は無視。全角記号\n",
    "\n",
    "    # 特定文字を正規表現で置換する\n",
    "    # 今回は無視\n",
    "\n",
    "    return text\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['昨日', 'は', 'とても', '暑く', '、', '気温', 'が', '00', '度', 'も', 'あっ', 'た', '。']\n"
     ]
    }
   ],
   "source": [
    "# 前処理とJanomeの単語分割を合わせた関数を定義する\n",
    "\n",
    "\n",
    "def tokenizer_with_preprocessing(text):\n",
    "    text = preprocessing_text(text)  # 前処理の正規化\n",
    "    ret = tokenizer_janome(text)  # Janomeの単語分割\n",
    "\n",
    "    return ret\n",
    "\n",
    "\n",
    "# 動作確認\n",
    "text = \"昨日は とても暑く、気温が36度もあった。\"\n",
    "print(tokenizer_with_preprocessing(text))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 文章データの読み込み"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchtext\n",
    "\n",
    "# tsvやcsvデータを読み込んだときに、読み込んだ内容に対して行う処理を定義します\n",
    "# 文章とラベルの両方に用意します\n",
    "\n",
    "max_length = 25\n",
    "TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer_with_preprocessing,\n",
    "                            use_vocab=True, lower=True, include_lengths=True, batch_first=True, fix_length=max_length)\n",
    "LABEL = torchtext.data.Field(sequential=False, use_vocab=False)\n",
    "\n",
    "# 引数の意味は次の通り\n",
    "# sequential: データの長さが可変か？文章は長さがいろいろなのでTrue.ラベルはFalse\n",
    "# tokenize: 文章を読み込んだときに、前処理や単語分割をするための関数を定義\n",
    "# use_vocab：単語をボキャブラリー（単語集：後で解説）に追加するかどうか\n",
    "# lower：アルファベットがあったときに小文字に変換するかどうか\n",
    "# include_length: 文章の単語数のデータを保持するか\n",
    "# batch_first：ミニバッチの次元を先頭に用意するかどうか\n",
    "# fix_length：全部の文章を指定した長さと同じになるように、paddingします\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "訓練データの数 4\n",
      "1つ目の訓練データ {'Text': ['王', 'と', '王子', 'と', '女王', 'と', '姫', 'と', '男性', 'と', '女性', 'が', 'い', 'まし', 'た', '。'], 'Label': '0'}\n",
      "2つ目の訓練データ {'Text': ['機械', '学習', 'が', '好き', 'です', '。'], 'Label': '1'}\n"
     ]
    }
   ],
   "source": [
    "# data.TabularDataset 詳細\n",
    "# https://torchtext.readthedocs.io/en/latest/examples.html?highlight=data.TabularDataset.splits\n",
    "\n",
    "# フォルダ「data」から各tsvファイルを読み込み、Datasetにします\n",
    "# 1行がTEXTとLABELで区切られていることをfieldsで指示します\n",
    "train_ds, val_ds, test_ds = torchtext.data.TabularDataset.splits(\n",
    "    path='./data/', train='text_train.tsv',\n",
    "    validation='text_val.tsv', test='text_test.tsv', format='tsv',\n",
    "    fields=[('Text', TEXT), ('Label', LABEL)])\n",
    "\n",
    "\n",
    "# 動作確認\n",
    "print('訓練データの数', len(train_ds))\n",
    "print('1つ目の訓練データ', vars(train_ds[0]))\n",
    "print('2つ目の訓練データ', vars(train_ds[1]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 単語の数値化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'王': 1,\n",
       "         'と': 5,\n",
       "         '王子': 1,\n",
       "         '女王': 1,\n",
       "         '姫': 1,\n",
       "         '男性': 1,\n",
       "         '女性': 1,\n",
       "         'が': 3,\n",
       "         'い': 1,\n",
       "         'まし': 1,\n",
       "         'た': 1,\n",
       "         '。': 4,\n",
       "         '機械': 1,\n",
       "         '学習': 1,\n",
       "         '好き': 1,\n",
       "         'です': 1,\n",
       "         '本章': 2,\n",
       "         'から': 1,\n",
       "         '自然': 1,\n",
       "         '言語': 1,\n",
       "         '処理': 1,\n",
       "         'に': 1,\n",
       "         '取り組み': 1,\n",
       "         'ます': 2,\n",
       "         'で': 1,\n",
       "         'は': 1,\n",
       "         '商品': 1,\n",
       "         'レビュー': 1,\n",
       "         'の': 4,\n",
       "         '短い': 1,\n",
       "         '文章': 4,\n",
       "         'に対して': 1,\n",
       "         '、': 3,\n",
       "         'その': 1,\n",
       "         'ネガティブ': 1,\n",
       "         'な': 4,\n",
       "         '評価': 2,\n",
       "         'を': 3,\n",
       "         'し': 3,\n",
       "         'て': 2,\n",
       "         'いる': 2,\n",
       "         'か': 2,\n",
       "         'ポジティブ': 1,\n",
       "         '0': 1,\n",
       "         '値': 1,\n",
       "         'クラス': 1,\n",
       "         '分類': 2,\n",
       "         'する': 1,\n",
       "         'モデル': 1,\n",
       "         '構築': 1})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ボキャブラリーを作成します\n",
    "# 訓練データtrainの単語からmin_freq以上の頻度の単語を使用してボキャブラリー（単語集）を構築\n",
    "TEXT.build_vocab(train_ds, min_freq=1)\n",
    "\n",
    "# 訓練データ内の単語と頻度を出力(頻度min_freqより大きいものが出力されます)\n",
    "TEXT.vocab.freqs  # 出力させる\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(<function torchtext.vocab._default_unk_index()>,\n",
       "            {'<unk>': 0,\n",
       "             '<pad>': 1,\n",
       "             'と': 2,\n",
       "             '。': 3,\n",
       "             'な': 4,\n",
       "             'の': 5,\n",
       "             '文章': 6,\n",
       "             '、': 7,\n",
       "             'が': 8,\n",
       "             'し': 9,\n",
       "             'を': 10,\n",
       "             'いる': 11,\n",
       "             'か': 12,\n",
       "             'て': 13,\n",
       "             'ます': 14,\n",
       "             '分類': 15,\n",
       "             '本章': 16,\n",
       "             '評価': 17,\n",
       "             '0': 18,\n",
       "             'い': 19,\n",
       "             'から': 20,\n",
       "             'する': 21,\n",
       "             'その': 22,\n",
       "             'た': 23,\n",
       "             'で': 24,\n",
       "             'です': 25,\n",
       "             'に': 26,\n",
       "             'に対して': 27,\n",
       "             'は': 28,\n",
       "             'まし': 29,\n",
       "             'クラス': 30,\n",
       "             'ネガティブ': 31,\n",
       "             'ポジティブ': 32,\n",
       "             'モデル': 33,\n",
       "             'レビュー': 34,\n",
       "             '値': 35,\n",
       "             '処理': 36,\n",
       "             '取り組み': 37,\n",
       "             '商品': 38,\n",
       "             '女性': 39,\n",
       "             '女王': 40,\n",
       "             '好き': 41,\n",
       "             '姫': 42,\n",
       "             '学習': 43,\n",
       "             '構築': 44,\n",
       "             '機械': 45,\n",
       "             '王': 46,\n",
       "             '王子': 47,\n",
       "             '男性': 48,\n",
       "             '短い': 49,\n",
       "             '自然': 50,\n",
       "             '言語': 51})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ボキャブラリーの単語をidに変換した結果を出力。\n",
    "# 頻度がmin_freqより小さい場合は未知語<unk>になる\n",
    "\n",
    "TEXT.vocab.stoi  # 出力。string to identifiers 文字列をidへ\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataLoaderの作成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[46,  2, 47,  2, 40,  2, 42,  2, 48,  2, 39,  8, 19, 29, 23,  3,  1,  1,\n",
      "          1,  1,  1,  1,  1,  1,  1],\n",
      "        [45, 43,  8, 41, 25,  3,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
      "          1,  1,  1,  1,  1,  1,  1]]), tensor([16,  6]))\n",
      "tensor([0, 1])\n"
     ]
    }
   ],
   "source": [
    "# DataLoaderを作成します（torchtextの文脈では単純にiteraterと呼ばれています）\n",
    "train_dl = torchtext.data.Iterator(train_ds, batch_size=2, train=True)\n",
    "\n",
    "val_dl = torchtext.data.Iterator(\n",
    "    val_ds, batch_size=2, train=False, sort=False)\n",
    "\n",
    "test_dl = torchtext.data.Iterator(\n",
    "    test_ds, batch_size=2, train=False, sort=False)\n",
    "\n",
    "\n",
    "# 動作確認 検証データのデータセットで確認\n",
    "batch = next(iter(val_dl))\n",
    "print(batch.Text)\n",
    "print(batch.Label)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
